Source code for hysop.tools.numerics

# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np
import gmpy2
from gmpy2 import mpq, mpz, mpfr, f2q

from hysop.constants import (
    HYSOP_REAL,
    HYSOP_INTEGER,
    HYSOP_INDEX,
    HYSOP_BOOL,
    HYSOP_COMPLEX,
)

# Concrete gmpy2 classes, recovered from sample instances so that they can be
# used in isinstance() checks elsewhere.
MPQ = type(mpq(0))
MPZ = type(mpz(0))
MPFR = type(mpfr(0))
# NOTE(review): gmpy2 documents that f2q returns an mpz when the denominator
# is 1, so F2Q may actually be the mpz class rather than mpq -- confirm intent.
F2Q = type(f2q(0))


def _mpqize(x):
    """Convert a scalar to a gmpy2 rational.

    Python ints (bool included) become ``mpq(x, 1)``, Python floats (and
    float subclasses such as ``np.float64``) are converted with ``f2q``,
    and any other value is parsed by ``mpq`` from its string form.
    """
    converters = (
        (int, lambda value: mpq(value, 1)),
        (float, f2q),
    )
    for kind, convert in converters:
        if isinstance(x, kind):
            return convert(x)
    # Fallback: let mpq parse the textual representation (e.g. numpy ints).
    return mpq(str(x))


# Elementwise (array) version of _mpqize.
mpqize = np.vectorize(_mpqize)


def get_dtype(x):
    """Return the numpy type (or dtype) associated with ``x``.

    Resolution order:

    * ``np.dtype`` instance -> its scalar type (``x.type``);
    * object exposing a ``dtype`` attribute:
        - callable ``x.dtype`` -> the result of calling it,
        - numpy scalar *class* (``dtype`` is an unbound getset_descriptor)
          -> ``x`` itself,
        - anything else (arrays, numpy scalars) -> ``x.dtype``;
    * Python ``int`` (bool included) -> ``np.int64``, ``float`` ->
      ``np.float64``, ``complex`` -> ``np.complex128``;
    * ``None`` -> ``None``.

    Raises
    ------
    TypeError
        If ``x`` matches none of the cases above.
    """
    if isinstance(x, np.dtype):
        return x.type

    if hasattr(x, "dtype"):
        attr = x.dtype
        if callable(attr):
            return attr()
        # On a numpy scalar class (e.g. np.float64) the ``dtype`` attribute
        # is an unbound descriptor, so the class itself is the type.
        if type(attr).__name__ == "getset_descriptor":
            return x
        return attr

    # Checked in order: bool is an int subclass and therefore maps to int64.
    python_scalars = (
        (int, np.int64),
        (float, np.float64),
        (complex, np.complex128),
    )
    for py_type, np_type in python_scalars:
        if isinstance(x, py_type):
            return np_type

    if x is None:
        return None

    raise TypeError("Unknown type in get_dtype (got {}).".format(x.__class__))
def get_itemsize(x):
    """Return the size in bytes of one element of ``x``'s dtype.

    ``x`` may be anything accepted by ``get_dtype``.
    NOTE: ``get_dtype(None)`` is ``None`` and ``np.dtype(None)`` is the
    default float64, so ``None`` reports an itemsize of 8.
    """
    return np.dtype(get_dtype(x)).itemsize
def is_fp(x):
    """Return True when ``x`` resolves to a real floating point type.

    Recognised types: float16, float32, float64 and longdouble.
    """
    resolved = get_dtype(x)
    # Equality scan rather than a set lookup: get_dtype may return an
    # np.dtype instance, which compares equal to the matching scalar type
    # but is not guaranteed to hash like it.
    return any(
        resolved == candidate
        for candidate in (np.float16, np.float32, np.float64, np.longdouble)
    )
def is_signed(x):
    """Return True when ``x`` resolves to a signed integer type.

    Recognised types: int8, int16, int32 and int64.
    """
    resolved = get_dtype(x)
    # Equality scan (not a set): get_dtype may return an np.dtype instance.
    return any(
        resolved == candidate
        for candidate in (np.int8, np.int16, np.int32, np.int64)
    )
def is_unsigned(x):
    """Return True when ``x`` resolves to an unsigned integer type.

    Recognised types: bool_, uint8, uint16, uint32 and uint64. Booleans are
    deliberately counted as unsigned integers.
    """
    resolved = get_dtype(x)
    # Equality scan (not a set): get_dtype may return an np.dtype instance.
    return any(
        resolved == candidate
        for candidate in (np.bool_, np.uint8, np.uint16, np.uint32, np.uint64)
    )
def is_integer(x):
    """Return True when ``x`` resolves to a signed or unsigned integer type
    (booleans included, since ``is_unsigned`` accepts them)."""
    # Short-circuits exactly like ``is_signed(x) or is_unsigned(x)``.
    return any(check(x) for check in (is_signed, is_unsigned))
def is_complex(x):
    """Return True when ``x`` resolves to a complex floating point type.

    Recognised types: complex64, complex128 and clongdouble.
    """
    resolved = get_dtype(x)
    # Equality scan (not a set): get_dtype may return an np.dtype instance.
    return any(
        resolved == candidate
        for candidate in (np.complex64, np.complex128, np.clongdouble)
    )
def default_invalid_value(dtype):
    """Return the sentinel value used to flag invalid entries of ``dtype``.

    * complex types        -> nan + nan*1j
    * floating point types -> nan
    * integer types (unsigned, bool included, or signed) -> 0

    Raises
    ------
    NotImplementedError
        For any other dtype.
    """
    nan = float("nan")
    if is_complex(dtype):
        return complex(nan, nan)
    if is_fp(dtype):
        return nan
    if is_unsigned(dtype) or is_signed(dtype):
        return 0
    raise NotImplementedError
# dtype promotion / demotion helpers
def match_dtype(x, dtype):
    """Promote x.dtype to dtype (always safe cast).

    When ``dtype`` is a kind string, the result is
    ``np.promote_types(get_dtype(x), partner)`` where ``partner`` is the
    smallest type of that kind: ``"f"`` -> float16, ``"i"`` -> int8,
    ``"u"`` -> uint8, ``"b"`` -> HYSOP_BOOL, ``"c"`` -> complex64.

    NOTE(review): when ``dtype`` is not a string it is returned unchanged
    (no promotion with x's dtype); ``None`` yields x's own dtype instead.

    Raises
    ------
    NotImplementedError
        For an unknown kind string.
    """
    xtype = get_dtype(x)

    if isinstance(dtype, str):
        smallest_of_kind = {
            "f": np.float16,
            "i": np.int8,
            "u": np.uint8,
            "b": HYSOP_BOOL,
            "c": np.complex64,
        }
        if dtype not in smallest_of_kind:
            raise NotImplementedError(dtype)
        return np.promote_types(xtype, smallest_of_kind[dtype])

    if xtype is None:
        return dtype
    if dtype is None:
        return xtype
    return dtype
def demote_dtype(x, dtype):
    """Demote x.dtype to dtype (not a safe cast).

    Parameters
    ----------
    x :
        Anything accepted by ``get_dtype``.
    dtype : str, dtype-like or None
        When a string, the target kind: ``"c"`` (complex), ``"f"`` (float),
        ``"i"`` (signed int) or ``"u"`` (unsigned int). The result is the
        type of that kind whose size matches the per-component itemsize of
        x's dtype (a complex itemsize is halved). When not a string,
        ``dtype`` is returned as-is, except that ``None`` yields x's dtype.

    Raises
    ------
    NotImplementedError
        For an unknown kind string.
    TypeError
        When a kind string is given but ``x`` has no dtype (``x is None``).
    KeyError
        When the requested kind has no type of the matching size.
    """
    xtype = get_dtype(x)

    if not isinstance(dtype, str):
        if xtype is None:
            return dtype
        if dtype is None:
            return xtype
        return dtype

    by_kind = {
        "c": {
            1: np.complex64,
            2: np.complex64,
            4: np.complex64,
            8: np.complex128,
            16: np.clongdouble,
        },
        "f": {
            1: np.float16,
            2: np.float16,
            4: np.float32,
            8: np.float64,
            16: np.longdouble,
        },
        "i": {1: np.int8, 2: np.int16, 4: np.int32, 8: np.int64},
        "u": {1: np.uint8, 2: np.uint16, 4: np.uint32, 8: np.uint64},
    }
    if dtype not in by_kind:
        raise NotImplementedError(dtype)
    if xtype is None:
        msg = "Cannot demote a value without dtype to kind {!r}."
        raise TypeError(msg.format(dtype))

    # np.dtype() accepts both scalar types and np.dtype instances; get_dtype
    # returns the latter for arrays, which are not callable like xtype(0).
    n = np.dtype(xtype).itemsize
    if is_complex(xtype):
        # Size of a single (real or imaginary) component.
        n //= 2
    return by_kind[dtype][n]
def match_float_type(x):
    """Promote the dtype of ``x`` with the smallest float type.

    Shorthand for ``match_dtype(x, "f")``.
    """
    return match_dtype(x, dtype="f")
def match_signed_type(x):
    """Promote the dtype of ``x`` with the smallest signed integer type.

    Shorthand for ``match_dtype(x, "i")``.
    """
    return match_dtype(x, dtype="i")
def match_unsigned_type(x):
    """Promote the dtype of ``x`` with the smallest unsigned integer type.

    Shorthand for ``match_dtype(x, "u")``, i.e.
    ``np.promote_types(get_dtype(x), np.uint8)``.
    """
    # "u" (unsigned), not "i": the signed kind belongs to match_signed_type.
    return match_dtype(x, "u")
def match_complex_type(x):
    """Promote the dtype of ``x`` with the smallest complex type.

    Shorthand for ``match_dtype(x, "c")``.
    """
    return match_dtype(x, dtype="c")
def match_bool_type(x):
    """Promote the dtype of ``x`` with the boolean type (HYSOP_BOOL).

    Shorthand for ``match_dtype(x, "b")``.
    """
    return match_dtype(x, dtype="b")
def complex_to_float_dtype(dtype):
    """Return the real floating point type matching a complex type.

    complex64 -> float32, complex128 -> float64, clongdouble -> longdouble.

    Raises
    ------
    AssertionError
        If ``dtype`` is not complex (only while assertions are enabled).
    RuntimeError
        If ``dtype`` has no known real counterpart.
    """
    dtype = get_dtype(dtype)
    assert is_complex(dtype), f"{dtype} is not a complex type"
    if dtype == np.complex64:
        return np.float32
    elif dtype == np.complex128:
        return np.float64
    elif dtype == np.clongdouble:
        return np.longdouble
    # Reachable e.g. under ``python -O``, where the assert above is stripped;
    # the message must be assigned before it is formatted.
    msg = "Unknown complex type {}."
    msg = msg.format(dtype)
    raise RuntimeError(msg)
def float_to_complex_dtype(dtype):
    """Return the complex type whose components have the given float type.

    float32 -> complex64, float64 -> complex128, longdouble -> clongdouble.

    Raises
    ------
    AssertionError
        If ``dtype`` is not a floating point type (assertions enabled).
    RuntimeError
        For a float type without complex counterpart (e.g. float16).
    """
    dtype = get_dtype(dtype)
    assert is_fp(dtype), f"{dtype} is not a float"
    # Ordered pairs compared with ==, since get_dtype may return an np.dtype
    # instance; the order matters where longdouble aliases float64.
    pairs = (
        (np.float32, np.complex64),
        (np.float64, np.complex128),
        (np.longdouble, np.clongdouble),
    )
    for real_type, complex_type in pairs:
        if dtype == real_type:
            return complex_type
    raise RuntimeError("Unknown float type {}.".format(dtype))
def determine_fp_types(dtype):
    """Return the ``(real, complex)`` pair of numpy dtypes matching ``dtype``.

    ``dtype`` may be either member of the pair: a float type is paired with
    its complex counterpart, and a complex type with its real component type.

    Raises
    ------
    ValueError
        If ``dtype`` is neither a floating point nor a complex type.
    """
    if is_fp(dtype):
        real_type, complex_type = dtype, float_to_complex_dtype(dtype)
    elif is_complex(dtype):
        real_type, complex_type = complex_to_float_dtype(dtype), dtype
    else:
        raise ValueError(
            "{} is not a floating point or complex data type.".format(dtype)
        )
    return np.dtype(real_type), np.dtype(complex_type)
def find_common_dtype(*args):
    """Return a numpy scalar type able to represent every argument.

    Each argument is resolved with ``get_dtype``. The result kind follows the
    precedence complex > float > signed int > unsigned int (bool counts as
    unsigned), and its size is the largest itemsize among the arguments.
    When the result is complex, a real argument of itemsize ``s`` requires a
    complex itemsize of ``2 * s`` (real + imaginary part), so for example
    ``(float64, complex64)`` yields complex128 rather than complex64.

    Raises
    ------
    ValueError
        If called without any argument.
    KeyError
        If the selected kind has no type of the required itemsize.
    NotImplementedError
        If no argument resolves to a supported numeric type.
    """
    if not args:
        raise ValueError("find_common_dtype() requires at least one argument.")
    dtypes = tuple(get_dtype(arg) for arg in args)
    itemsizes = tuple(get_itemsize(x) for x in dtypes)

    if any(is_complex(x) for x in dtypes):
        # A real (float/int/bool) value needs twice its size once complex.
        n = max(
            2 * size if (is_fp(x) or is_signed(x) or is_unsigned(x)) else size
            for x, size in zip(dtypes, itemsizes)
        )
        return {8: np.complex64, 16: np.complex128, 32: np.clongdouble}[n]

    n = max(itemsizes)
    if any(is_fp(x) for x in dtypes):
        return {2: np.float16, 4: np.float32, 8: np.float64, 16: np.longdouble}[n]
    elif any(is_signed(x) for x in dtypes):
        return {1: np.int8, 2: np.int16, 4: np.int32, 8: np.int64}[n]
    elif any(is_unsigned(x) for x in dtypes):
        return {1: np.uint8, 2: np.uint16, 4: np.uint32, 8: np.uint64}[n]
    else:
        msg = "Did not find any matching dtype."
        raise NotImplementedError(msg)